Fix `save_hyperparameters` not crashing on dataclass with `init=False` #21051
Conversation
Thank you! I wonder if the warning is useful. I encountered #21036 by doing exactly what the warning suggests to do: by initializing the attributes in `__post_init__`. In that case, the warning would instruct the user to do exactly what they're already doing, which might be confusing and make them wonder what they should do exactly. WDYT?
```python
obj_fields = fields(obj)
init_args = {f.name: getattr(obj, f.name) for f in obj_fields if f.init}
if any(not f.init for f in obj_fields):
    rank_zero_warn(
```
See discussion above about this warning
Co-authored-by: Quentin Soubeyran <[email protected]>
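For readers following the diff above, here is a minimal, self-contained sketch of the same filtering logic (the `Config` class is illustrative, not from the PR). Fields declared with `init=False` are not constructor arguments, so they are excluded from the collected init arguments:

```python
# Illustrative sketch of the filtering in the diff above, not Lightning's actual code.
import dataclasses


@dataclasses.dataclass
class Config:
    lr: float
    cache: dict = dataclasses.field(init=False, default_factory=dict)


obj = Config(lr=0.1)
# Keep only fields that are constructor arguments (f.init is True).
init_args = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj) if f.init}
print(init_args)  # {'lr': 0.1} -- 'cache' (init=False) is skipped
```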
@QuentinSoubeyranAqemia what I was thinking users would need to do would be something like this:

```python
import dataclasses

import lightning.pytorch as L


@dataclasses.dataclass
class Module(L.LightningModule):
    a: float
    b: float
    c: float = 0.0

    def __post_init__(self):
        self.c = self.a + self.b
        self.save_hyperparameters()


model = Module(a=1, b=2)
print(model.hparams)
```

this is kind of misuse of `save_hyperparameters`
Perhaps the minimal reproducing example I provided is causing some confusion. In practice, the attributes that need to be excluded with `init=False` are not hyperparameters. The use-case is for internal `nn.Module` attributes.

Here's an example which is hopefully clearer: transforming a (simplistic) vanilla lightning module to leverage dataclasses, from

```python
import lightning.pytorch as L


class Module(L.LightningModule):
    input_dim: int
    hparam: int
    readout_dim: int
    model: MyModule
    readout: nn.Linear

    # we need to write a boilerplate signature again, though it is already laid out above
    def __init__(self, input_dim: int, hparam: int, readout_dim: int):
        super().__init__()
        self.save_hyperparameters()
        # boilerplate to store hyper-parameters
        self.input_dim = input_dim
        self.hparam = hparam
        self.readout_dim = readout_dim
        # create internals
        self.model = MyModule(self.input_dim, self.hparam)  # nn.Module
        self.readout = nn.Linear(..., self.readout_dim)
```

into the more concise

```python
import dataclasses

import lightning.pytorch as L


@dataclasses.dataclass(...)  # some specific flag needed here, not the point of this discussion
class Module(L.LightningModule):
    input_dim: int
    hparam: int
    readout_dim: int
    # internals, not hparams, exclude them from __init__ and save_hyperparameters()
    model: MyModule = dataclasses.field(init=False)
    readout: nn.Linear = dataclasses.field(init=False)

    def __post_init__(self):
        # no boilerplate!
        super().__init__()
        self.save_hyperparameters()
        self.model = MyModule(self.input_dim, self.hparam)  # nn.Module
        self.readout = nn.Linear(..., self.readout_dim)
```
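Editor's note, not part of the original thread: a plausible reconstruction of why the pattern above crashed before this fix is that an `init=False` field without a default simply does not exist on the instance until `__post_init__` assigns it, so `getattr` on it raises `AttributeError`. A standalone demonstration using only `dataclasses` (the `Demo` class is illustrative):

```python
import dataclasses


@dataclasses.dataclass
class Demo:
    a: int
    internal: list = dataclasses.field(init=False)

    def __post_init__(self):
        # `internal` has not been assigned yet, so it is absent from the
        # instance; iterating over *all* fields and calling getattr here
        # would raise AttributeError. Filtering on f.init avoids touching it.
        assert not hasattr(self, "internal")
        self.internal = []


Demo(a=1)
```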
@QuentinSoubeyranAqemia thanks for expanding on the use case, it all makes sense now. I have removed the warning since you are right that it does not make sense.
What does this PR do?
Fixes #21036
Skip attributes where the user has set `init=False`.
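To illustrate the fix in context, a hedged sketch of the pattern discussed above. The thread deliberately leaves the dataclass flag unspecified, so `eq=False` is an assumption on my part (it keeps `nn.Module`'s `__hash__` intact), as are the class and field names:

```python
import dataclasses

import lightning.pytorch as L
import torch.nn as nn


@dataclasses.dataclass(eq=False)  # eq=False is an assumption; the thread leaves the flag open
class LitModel(L.LightningModule):
    input_dim: int
    output_dim: int
    layer: nn.Linear = dataclasses.field(init=False)  # internal, not a hyper-parameter

    def __post_init__(self):
        super().__init__()
        self.save_hyperparameters()  # with this PR, the init=False field is skipped
        self.layer = nn.Linear(self.input_dim, self.output_dim)


model = LitModel(input_dim=4, output_dim=2)
print(model.hparams)  # only input_dim and output_dim are saved
```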
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines.
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--21051.org.readthedocs.build/en/21051/